import torch
import torch.nn as nn
import time
from tensorboardX import SummaryWriter
from tools.utils import *
from common import *
import torch.nn.init as init


class AlignNet(nn.Module):
    """Alignment network: disparity estimation + per-view SR + warping.

    ``forward`` takes a low-resolution light field ``[B, U*V, h, w]`` and an
    augmentation scale and performs, in order:

    1. plane-sweep volume (PSV) construction from the reflection-padded input,
    2. low-resolution disparity regression from the PSV features,
    3. upsampling of every view with ``self.resizer`` followed by a VDSR-style
       residual refinement,
    4. warping of the refined views with the upsampled disparity.

    The warping helper, ``default_conv``, ``Conv_ReLU_Block`` and
    ``bicubic_imresize`` are defined outside this file.
    """

    def __init__(self,
                 refPos,
                 scale=2,
                 conv=default_conv,
                 padding_mode="zeros",
                 level_num=100,
                 level_step=0.1,
                 pad_size=12,
                 view_num=9):
        """
        :param refPos:        two-element index ``(u, v)`` of the reference view
        :param scale:         upsampling factor passed to ``self.resizer``
        :param conv:          layer factory called as ``conv(in_channels, out_channels, kernel_size)``
        :param padding_mode:  forwarded unchanged to the warping helper
        :param level_num:     number of disparity levels in the PSV
        :param level_step:    disparity distance between adjacent levels (before scaling by aug_scale)
        :param pad_size:      reflection padding added before warping and cropped afterwards
        :param view_num:      views per angular axis; inputs carry ``view_num * view_num`` channels
        """
        super(AlignNet, self).__init__()

        self.refPos = refPos
        self.level_num = level_num
        self.level_step = level_step
        self.view_num = view_num
        self.scale = scale
        self.padding_mode = padding_mode
        self.pad_size = pad_size
        self.padder = nn.ReflectionPad2d(self.pad_size)
        # parameter-free helpers: activation and image resizer
        self.relu = nn.ReLU(inplace=True)
        self.resizer = bicubic_imresize()
        # feature extraction: 1x1 conv over the angular (view) channels
        self.feat_conv = conv(in_channels=self.view_num * self.view_num, out_channels=2, kernel_size=1)
        # disparity estimation: consumes 2 feature channels per disparity level
        self.disp_conv1 = conv(2 * self.level_num, 100, 7)
        self.disp_conv2 = conv(100, 100, 5)
        self.disp_conv3 = conv(100, 50, 3)
        self.disp_conv4 = conv(50, 1, 1)
        # VDSR: single-channel residual refinement applied to each view
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=True)
        # NOTE: this constructor applies no explicit weight initialisation.

    def VDSR_net(self, bic_lf):
        """
        Refine every upsampled view with the shared VDSR residual network.

        :param bic_lf:   [N, U*V, H, W] upsampled light field
        :return: (vdsr_lf [N, U*V, H, W], vdsr_ref [N, 1, H, W]) where vdsr_ref
                 is the refined view at ``self.refPos``
        """
        N, _, H, W = bic_lf.shape
        # reshape (not view): the resizer output is not guaranteed to be contiguous
        views = bic_lf.reshape(-1, 1, H, W)
        feat = self.relu(self.input(views))
        feat = self.residual_layer(feat)
        refined = views + self.output(feat)
        # split the channel axis into (U, V) to pick the reference view
        refined = refined.view(N, self.view_num, -1, H, W)
        vdsr_ref = refined[:, self.refPos[0], self.refPos[1], :, :].unsqueeze(1)
        vdsr_lf = refined.view(N, -1, H, W)
        return vdsr_lf, vdsr_ref

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` freshly constructed ``block()`` modules sequentially."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def _disparity_planes(self, padded_lf, aug_scale):
        """Lazily yield one constant disparity map [B, 2, h, w] per PSV level."""
        B, _, h, w = padded_lf.shape
        # rescale the level step for multi-scale generalization
        level_step = self.level_step * aug_scale
        center = (self.level_num - 1) / 2.0
        # allocate on the input's own device/dtype instead of CPU + .cuda()
        ones = padded_lf.new_ones((B, 2, h, w))
        for level in range(self.level_num):
            yield ones * ((level - center) * level_step)

    def _crop(self, x):
        """Remove the ``pad_size`` border from the last two axes (no-op for pad_size == 0)."""
        p = self.pad_size
        # explicit end indices: a ``p:-p`` slice would be empty when p == 0
        return x[:, :, p:x.shape[2] - p, p:x.shape[3] - p]

    def PSV_generator(self, input_lf, aug_scale, range_angular, range_spatial_x, range_spatial_y):
        """
        PSV generator for parallel computation (all levels warped in one call).

        :param input_lf:                        [B, U*V, h, w]
        :param aug_scale:                       float, scale value generated by scaling augmentation during training
        :param range_angular:                   angular range, forwarded to the warping helper
        :param range_spatial_x:                 spatial range of the padded input, forwarded to the warping helper
        :param range_spatial_y:                 spatial range of the padded input, forwarded to the warping helper
        :return: PSV:                           [LB, UV, h, w], level-major along the first axis
        """
        input_lf = self.padder(input_lf)
        disparity_levels = torch.cat(list(self._disparity_planes(input_lf, aug_scale)), dim=0)  # [LB, 2, h, w]
        input_lf = input_lf.repeat(self.level_num, 1, 1, 1)  # [LB, UV, h, w]
        PSV = warp_to_ref_view_parallel_double_range(input_lf, disparity_levels, self.refPos,
                                                     arange_angular=range_angular,
                                                     arange_spatial_x=range_spatial_x,
                                                     arange_spatial_y=range_spatial_y,
                                                     padding_mode=self.padding_mode)  # [LB, UV, h, w]
        return self._crop(PSV)  # [LB, UV, h, w]

    @staticmethod
    def _set_requires_grad(modules, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of the given modules."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    def set_zero_lr_AlignNet(self):
        """
        Freeze the parameters of AlignNet (feature and disparity convolutions).
        """
        self._set_requires_grad((self.feat_conv, self.disp_conv1, self.disp_conv2,
                                 self.disp_conv3, self.disp_conv4), False)

    def set_zero_lr_VDSR(self):
        """
        Freeze the parameters of VDSR.
        """
        self._set_requires_grad((self.residual_layer, self.input, self.output), False)

    def _estimate_disparity(self, PSV, B, h, w):
        """Regress the low-resolution disparity map from a level-major PSV [LB, UV, h, w]."""
        depth_feats = self.relu(self.feat_conv(PSV))  # [LB, 2, h, w]
        depth_feats = depth_feats.view(self.level_num, B, -1, h, w)  # [L, B, 2, h, w]
        depth_feats = depth_feats.permute(1, 2, 0, 3, 4)  # [B, 2, L, h, w]
        depth_feats = depth_feats.contiguous().view(B, -1, h, w)  # [B, 2L, h, w]
        feats = self.relu(self.disp_conv1(depth_feats))
        feats = self.relu(self.disp_conv2(feats))
        feats = self.relu(self.disp_conv3(feats))
        return self.disp_conv4(feats)

    def forward(self, lf_input, aug_scale, range_angular, range_spatial_x_lr_pad, range_spatial_y_lr_pad,
                range_spatial_x_hr_pad, range_spatial_y_hr_pad):
        """
        The forward should concern about the augmentation scale.

        :param lf_input:                       [B, U*V, h, w]
        :param aug_scale:                      float, augmentation scale (never modified in place)
        :param range_angular:                  [GpuNum, U], angular range with multiple GPUs
        :param range_spatial_x_lr_pad:         [GpuNum, w+2*pad]
        :param range_spatial_y_lr_pad:         [GpuNum, h+2*pad]
        :param range_spatial_x_hr_pad:         [GpuNum, W+2*pad]
        :param range_spatial_y_hr_pad:         [GpuNum, H+2*pad]
        :return: tuple ``(aligned_vdsr_lf, vdsr_ref, lr_disp, bic_disp)``:
                 aligned_vdsr_lf -- refined views after warping and cropping,
                 vdsr_ref        -- [B, 1, H, W] refined reference view,
                 lr_disp         -- single-channel low-resolution disparity,
                 bic_disp        -- lr_disp resized by ``scale`` and multiplied by ``scale``
        """
        # drop the singleton leading axis left by the per-GPU split
        range_angular = range_angular.squeeze()
        range_spatial_x_lr_pad = range_spatial_x_lr_pad.squeeze()
        range_spatial_y_lr_pad = range_spatial_y_lr_pad.squeeze()
        range_spatial_x_hr_pad = range_spatial_x_hr_pad.squeeze()
        range_spatial_y_hr_pad = range_spatial_y_hr_pad.squeeze()
        ### --- First, generate LR PSV
        B, h, w = lf_input.shape[0], lf_input.shape[2], lf_input.shape[3]
        # the aug_scale is divided by the SR scale for LR PSV generation;
        # a new value is bound so a tensor argument is not mutated in place
        lr_aug_scale = aug_scale / self.scale
        PSV = self.PSV_generator(lf_input, lr_aug_scale,
                                 range_angular,
                                 range_spatial_x_lr_pad,
                                 range_spatial_y_lr_pad)  # [LB, UV, h, w]
        ### --- Second, disparity estimation
        lr_disp = self._estimate_disparity(PSV, B, h, w)
        ### --- Third, SR
        bic_lf = self.resizer(lf_input, self.scale)  # [B, UV, H, W]
        vdsr_lf, vdsr_ref = self.VDSR_net(bic_lf)

        # disparity values grow with resolution, hence the extra factor
        bic_disp = self.resizer(lr_disp, self.scale) * self.scale
        disparity = bic_disp.repeat(1, 2, 1, 1)
        ### --- Finally, padding, warping and cropping
        vdsr_lf = self.padder(vdsr_lf)
        disparity = self.padder(disparity)

        aligned_vdsr_lf = warp_to_ref_view_parallel_double_range(vdsr_lf,
                                                               disparity,
                                                               self.refPos,
                                                               range_angular,
                                                               range_spatial_x_hr_pad,
                                                               range_spatial_y_hr_pad,
                                                               padding_mode=self.padding_mode)
        ### --- Crop and return
        aligned_vdsr_lf = self._crop(aligned_vdsr_lf)
        return aligned_vdsr_lf, vdsr_ref, lr_disp, bic_disp


class AlignNet_ForTest(nn.Module):
    """Inference variant of the alignment network, written to avoid OOM.

    It holds the same layers as the training network but processes the views
    one at a time in ``VDSR_net`` and can build the plane-sweep volume (PSV)
    one disparity level at a time (``PSV_generator_serial``). Warping uses
    ``warp_to_ref_view_serial_no_range``, which is defined outside this file.
    """

    def __init__(self,
                 refPos,
                 scale=2,
                 conv=default_conv,
                 padding_mode="zeros",
                 level_num=100,
                 level_step=0.1,
                 pad_size=12,
                 view_num=9):
        """
        :param refPos:        two-element index ``(u, v)`` of the reference view
        :param scale:         upsampling factor passed to ``self.resizer``
        :param conv:          layer factory called as ``conv(in_channels, out_channels, kernel_size)``
        :param padding_mode:  forwarded unchanged to the warping helper
        :param level_num:     number of disparity levels in the PSV
        :param level_step:    disparity distance between adjacent levels (before scaling by aug_scale)
        :param pad_size:      reflection padding added before warping and cropped afterwards
        :param view_num:      views per angular axis; inputs carry ``view_num * view_num`` channels
        """
        super(AlignNet_ForTest, self).__init__()

        self.refPos = refPos
        self.level_num = level_num
        self.level_step = level_step
        self.view_num = view_num
        self.scale = scale
        self.padding_mode = padding_mode
        self.pad_size = pad_size
        self.padder = nn.ReflectionPad2d(self.pad_size)
        # parameter-free helpers: activation and image resizer
        self.relu = nn.ReLU(inplace=True)
        self.resizer = bicubic_imresize()
        # feature extraction: 1x1 conv over the angular (view) channels
        self.feat_conv = conv(in_channels=self.view_num * self.view_num, out_channels=2, kernel_size=1)
        # disparity estimation: consumes 2 feature channels per disparity level
        self.disp_conv1 = conv(2 * self.level_num, 100, 7)
        self.disp_conv2 = conv(100, 100, 5)
        self.disp_conv3 = conv(100, 50, 3)
        self.disp_conv4 = conv(50, 1, 1)

        # VDSR: single-channel residual refinement applied to each view
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=True)
        # NOTE: this constructor applies no explicit weight initialisation.

    def VDSR_net(self, bic_lf):
        """
        Refine the upsampled views one at a time to keep peak memory low.

        :param bic_lf:   [N, U*V, H, W] upsampled light field
        :return: (vdsr_lf [N, U*V, H, W], vdsr_ref [N, 1, H, W]); vdsr_ref is
                 ``None`` when the flattened reference index is not in ``range(U*V)``
        """
        UV = bic_lf.shape[1]
        # flattened (row-major) channel index of the reference view
        ref_index = self.view_num * self.refPos[0] + self.refPos[1]
        refined_views = []
        for i in range(UV):
            bic_sub = bic_lf[:, i, :, :].unsqueeze(1)
            feat = self.relu(self.input(bic_sub))
            feat = self.residual_layer(feat)
            refined_views.append(self.output(feat) + bic_sub)
        vdsr_ref = refined_views[ref_index] if 0 <= ref_index < UV else None
        vdsr_lf = torch.cat(refined_views, dim=1)
        return vdsr_lf, vdsr_ref

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` freshly constructed ``block()`` modules sequentially."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def _disparity_planes(self, padded_lf, aug_scale):
        """Lazily yield one constant disparity map [B, 2, h, w] per PSV level."""
        B, _, h, w = padded_lf.shape
        # rescale the level step for multi-scale generalization
        level_step = self.level_step * aug_scale
        center = (self.level_num - 1) / 2.0
        # allocate on the input's own device/dtype instead of CPU + .cuda()
        ones = padded_lf.new_ones((B, 2, h, w))
        for level in range(self.level_num):
            yield ones * ((level - center) * level_step)

    def _crop(self, x):
        """Remove the ``pad_size`` border from the last two axes (no-op for pad_size == 0)."""
        p = self.pad_size
        # explicit end indices: a ``p:-p`` slice would be empty when p == 0
        return x[:, :, p:x.shape[2] - p, p:x.shape[3] - p]

    def PSV_generator(self, input_lf, aug_scale, padding_mode=None):
        """
        Build the whole PSV with a single warping call (used for "low" resolution inputs).

        :param input_lf:      [B, U*V, h, w]
        :param aug_scale:     float, scales ``level_step``
        :param padding_mode:  optional override; ``None`` (default) uses
                              ``self.padding_mode`` so this path matches
                              ``PSV_generator_serial`` and the final warp
        :return: PSV          [LB, UV, h, w], level-major along the first axis
        """
        if padding_mode is None:
            padding_mode = self.padding_mode
        input_lf = self.padder(input_lf)
        disparity_levels = torch.cat(list(self._disparity_planes(input_lf, aug_scale)), dim=0)  # [LB, 2, h, w]
        input_lf = input_lf.repeat(self.level_num, 1, 1, 1)  # [LB, U*V, h, w]
        PSV = warp_to_ref_view_serial_no_range(input_lf, disparity_levels, self.refPos,
                                               padding_mode=padding_mode)  # [LB, U*V, h, w]
        return self._crop(PSV)

    def PSV_generator_serial(self, input_lf, aug_scale):
        """
        This function calculate plane-sweeping for each level to avoid OOM.

        :param input_lf:            [B, U*V, h, w]
        :param aug_scale:           float
        :return: PSV                [LB, UV, h, w], level-major along the first axis
        """
        input_lf = self.padder(input_lf)
        PSV = []
        # planes are produced lazily, so only one level is warped at a time
        for disparity_level in self._disparity_planes(input_lf, aug_scale):
            PSV_l = warp_to_ref_view_serial_no_range(input_lf, disparity_level, self.refPos,
                                                     padding_mode=self.padding_mode)  # [B, UV, h, w]
            PSV.append(self._crop(PSV_l))  # [B, UV, h, w]
        return torch.cat(PSV, dim=0)  # [LB, UV, h, w]

    def _estimate_disparity(self, PSV, B, h, w):
        """Regress the low-resolution disparity map from a level-major PSV [LB, UV, h, w]."""
        depth_feats = self.relu(self.feat_conv(PSV))  # [LB, 2, h, w]
        depth_feats = depth_feats.view(self.level_num, B, -1, h, w)  # [L, B, 2, h, w]
        depth_feats = depth_feats.permute(1, 2, 0, 3, 4)  # [B, 2, L, h, w]
        depth_feats = depth_feats.contiguous().view(B, -1, h, w)  # [B, 2L, h, w]
        feats = self.relu(self.disp_conv1(depth_feats))
        feats = self.relu(self.disp_conv2(feats))
        feats = self.relu(self.disp_conv3(feats))
        return self.disp_conv4(feats)

    def forward(self, lf_input, aug_scale, dset_resolution):
        """
        The forward should concern about the augmentation scale.

        :param lf_input: [B, U*V, h, w]
        :param aug_scale: a number less than 1 (never modified in place)
        :param dset_resolution: should be "low" or "high", it depends on the input resolution;
                                "high" builds the PSV level by level, "low" in one call
        :return: tuple ``(aligned_vdsr_lf, vdsr_ref, lr_disp, bic_disp)``:
                 aligned_vdsr_lf -- refined views after warping and cropping,
                 vdsr_ref        -- [B, 1, H, W] refined reference view,
                 lr_disp         -- single-channel low-resolution disparity,
                 bic_disp        -- lr_disp resized by ``scale`` and multiplied by ``scale``
        :raises ValueError: if ``dset_resolution`` is neither "high" nor "low"
        """
        ### --- First, generate LR PSV
        B, h, w = lf_input.shape[0], lf_input.shape[2], lf_input.shape[3]
        # the aug_scale is divided by the SR scale for LR PSV generation;
        # a new value is bound so a tensor argument is not mutated in place
        lr_aug_scale = aug_scale / self.scale
        if dset_resolution == "high":
            PSV = self.PSV_generator_serial(lf_input, lr_aug_scale)  # [LB, UV, h, w]
        elif dset_resolution == "low":
            PSV = self.PSV_generator(lf_input, lr_aug_scale)
        else:
            raise ValueError("Wrong dset_resolution: expected \"high\" or \"low\", got %r" % (dset_resolution,))
        ### --- Second, disparity estimation
        lr_disp = self._estimate_disparity(PSV, B, h, w)
        ### --- Third, SR
        bic_lf = self.resizer(lf_input, self.scale)
        vdsr_lf, vdsr_ref = self.VDSR_net(bic_lf)

        # disparity values grow with resolution, hence the extra factor
        bic_disp = self.resizer(lr_disp, self.scale) * self.scale
        disparity = bic_disp.repeat(1, 2, 1, 1)
        ### --- Finally, padding, warping and cropping
        vdsr_lf = self.padder(vdsr_lf)
        disparity = self.padder(disparity)

        aligned_vdsr_lf = warp_to_ref_view_serial_no_range(vdsr_lf,
                                                         disparity,
                                                         self.refPos,
                                                         padding_mode=self.padding_mode)
        ### --- Crop and return
        aligned_vdsr_lf = self._crop(aligned_vdsr_lf)
        return aligned_vdsr_lf, vdsr_ref, lr_disp, bic_disp


class AlignWithAggreNet(nn.Module):
    """Complete network designed by a divide-and-conquer strategy.

    It contains the alignment part (PSV construction, disparity estimation,
    warping), a pre-upsampler with VDSR-style refinement, and an aggregation
    (fusion) part that turns the aligned views into a residual added to the
    refined reference view.

    The warping helper, ``default_conv``, ``Conv_ReLU_Block`` and
    ``bicubic_imresize`` are defined outside this file.
    """

    def __init__(self,
                 refPos,
                 scale=2,
                 conv=default_conv,
                 padding_mode="zeros",
                 level_num=100,
                 level_step=0.1,
                 pad_size=12,
                 view_num=9):
        """
        :param refPos:        two-element index ``(u, v)`` of the reference view
        :param scale:         upsampling factor passed to ``self.resizer``
        :param conv:          layer factory called as ``conv(in_channels, out_channels, kernel_size)``
        :param padding_mode:  forwarded unchanged to the warping helper
        :param level_num:     number of disparity levels in the PSV
        :param level_step:    disparity distance between adjacent levels (before scaling by aug_scale)
        :param pad_size:      reflection padding added before warping and cropped afterwards
        :param view_num:      views per angular axis; inputs carry ``view_num * view_num`` channels
        """
        super(AlignWithAggreNet, self).__init__()

        print("Current network is the original shallow one")

        self.refPos = refPos
        self.level_num = level_num
        self.level_step = level_step
        self.view_num = view_num
        self.scale = scale
        self.padding_mode = padding_mode
        self.pad_size = pad_size
        self.padder = nn.ReflectionPad2d(self.pad_size)
        # parameter-free helpers: activation and image resizer
        self.relu = nn.ReLU(inplace=True)
        self.resizer = bicubic_imresize()
        # feature extraction: 1x1 conv over the angular (view) channels
        self.feat_conv = conv(in_channels=self.view_num * self.view_num, out_channels=2, kernel_size=1)
        # disparity estimation: consumes 2 feature channels per disparity level
        self.disp_conv1 = conv(2 * self.level_num, 100, 7)
        self.disp_conv2 = conv(100, 100, 5)
        self.disp_conv3 = conv(100, 50, 3)
        self.disp_conv4 = conv(50, 1, 1)

        # VDSR: single-channel residual refinement applied to each view
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=True)

        # fusion net: all fusion layers share one kernel size
        self.kernel_size = 7

        self.fusion_conv1 = conv(self.view_num * self.view_num, 32, self.kernel_size)
        self.fusion_conv2 = conv(32, 64, self.kernel_size)
        self.fusion_conv3 = conv(64, 32, self.kernel_size)
        self.fusion_conv4 = conv(32, 16, self.kernel_size)
        self.fusion_tail = conv(16, 1, self.kernel_size)
        # NOTE: this constructor applies no explicit weight initialisation.

    def VDSR_net(self, bic_lf):
        """
        Refine every upsampled view with the shared VDSR residual network.

        :param bic_lf:   [N, U*V, H, W] upsampled light field
        :return: (vdsr_lf [N, U*V, H, W], vdsr_ref [N, 1, H, W]) where vdsr_ref
                 is the refined view at ``self.refPos``
        """
        N, _, H, W = bic_lf.shape
        # reshape (not view): the resizer output is not guaranteed to be contiguous
        views = bic_lf.reshape(-1, 1, H, W)
        feat = self.relu(self.input(views))
        feat = self.residual_layer(feat)
        refined = views + self.output(feat)
        # split the channel axis into (U, V) to pick the reference view
        refined = refined.view(N, self.view_num, -1, H, W)
        vdsr_ref = refined[:, self.refPos[0], self.refPos[1], :, :].unsqueeze(1)
        vdsr_lf = refined.view(N, -1, H, W)
        return vdsr_lf, vdsr_ref

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` freshly constructed ``block()`` modules sequentially."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def _disparity_planes(self, padded_lf, aug_scale):
        """Lazily yield one constant disparity map [B, 2, h, w] per PSV level."""
        B, _, h, w = padded_lf.shape
        # rescale the level step for multi-scale generalization
        level_step = self.level_step * aug_scale
        center = (self.level_num - 1) / 2.0
        # allocate on the input's own device/dtype instead of CPU + .cuda()
        ones = padded_lf.new_ones((B, 2, h, w))
        for level in range(self.level_num):
            yield ones * ((level - center) * level_step)

    def _crop(self, x):
        """Remove the ``pad_size`` border from the last two axes (no-op for pad_size == 0)."""
        p = self.pad_size
        # explicit end indices: a ``p:-p`` slice would be empty when p == 0
        return x[:, :, p:x.shape[2] - p, p:x.shape[3] - p]

    def PSV_generator(self, input_lf, aug_scale, range_angular, range_spatial_x, range_spatial_y):
        """
        Build the plane-sweep volume with one warping call over all levels.

        :param input_lf:          [B, U*V, h, w]
        :param aug_scale:         float, scales ``level_step``
        :param range_angular:     angular range, forwarded to the warping helper
        :param range_spatial_x:   spatial range of the padded input, forwarded to the warping helper
        :param range_spatial_y:   spatial range of the padded input, forwarded to the warping helper
        :return: PSV              [LB, UV, h, w], level-major along the first axis
        """
        input_lf = self.padder(input_lf)
        disparity_levels = torch.cat(list(self._disparity_planes(input_lf, aug_scale)), dim=0)  # [LB, 2, h, w]
        input_lf = input_lf.repeat(self.level_num, 1, 1, 1)  # [LB, UV, h, w]
        PSV = warp_to_ref_view_parallel_double_range(input_lf, disparity_levels, self.refPos,
                                                     arange_angular=range_angular,
                                                     arange_spatial_x=range_spatial_x,
                                                     arange_spatial_y=range_spatial_y,
                                                     padding_mode=self.padding_mode)  # [LB, UV, h, w]
        return self._crop(PSV)  # [LB, UV, h, w]

    @staticmethod
    def _set_requires_grad(modules, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of the given modules."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    def _align_modules(self):
        """Return the layers that make up the alignment (disparity) part."""
        return (self.feat_conv, self.disp_conv1, self.disp_conv2,
                self.disp_conv3, self.disp_conv4)

    def set_zero_lr_AlignNet(self):
        """Freeze the alignment layers (optional)."""
        self._set_requires_grad(self._align_modules(), False)

    def free_AlignNet(self):
        """Make the alignment layers trainable again (optional)."""
        self._set_requires_grad(self._align_modules(), True)

    def set_zero_lr_VDSR(self):
        """Freeze the VDSR layers."""
        self._set_requires_grad((self.residual_layer, self.input, self.output), False)

    def _estimate_disparity(self, PSV, B, h, w):
        """Regress the low-resolution disparity map from a level-major PSV [LB, UV, h, w]."""
        depth_feats = self.relu(self.feat_conv(PSV))  # [LB, 2, h, w]
        depth_feats = depth_feats.view(self.level_num, B, -1, h, w)  # [L, B, 2, h, w]
        depth_feats = depth_feats.permute(1, 2, 0, 3, 4)  # [B, 2, L, h, w]
        depth_feats = depth_feats.contiguous().view(B, -1, h, w)  # [B, 2L, h, w]
        feats = self.relu(self.disp_conv1(depth_feats))
        feats = self.relu(self.disp_conv2(feats))
        feats = self.relu(self.disp_conv3(feats))
        return self.disp_conv4(feats)

    def _fuse(self, aligned_vdsr_lf, vdsr_ref):
        """Aggregate the aligned views into a residual and add it to the reference view."""
        feat = self.relu(self.fusion_conv1(aligned_vdsr_lf))
        feat = self.relu(self.fusion_conv2(feat))
        feat = self.relu(self.fusion_conv3(feat))
        feat = self.relu(self.fusion_conv4(feat))
        return self.fusion_tail(feat) + vdsr_ref

    def forward(self, lf_input, aug_scale, range_angular,
                range_spatial_x_lr_pad,
                range_spatial_y_lr_pad,
                range_spatial_x_hr_pad,
                range_spatial_y_hr_pad):
        """
        The forward should concern about the augmentation scale.

        :param lf_input: [B, U*V, h, w]
        :param aug_scale: a number less than 1 (never modified in place)
        :param range_angular: [GpuNum, U]
        :param range_spatial_x_lr_pad: [GpuNum, w+2*pad]
        :param range_spatial_y_lr_pad: [GpuNum, h+2*pad]
        :param range_spatial_x_hr_pad: [GpuNum, W+2*pad]
        :param range_spatial_y_hr_pad: [GpuNum, H+2*pad]
        :return: tuple ``(aligned_vdsr_lf, vdsr_ref, sr_ref, lr_disp, bic_disp)``:
                 aligned_vdsr_lf -- refined views after warping and cropping,
                 vdsr_ref        -- [B, 1, H, W] refined reference view,
                 sr_ref          -- fusion output added to vdsr_ref,
                 lr_disp         -- single-channel low-resolution disparity,
                 bic_disp        -- lr_disp resized by ``scale`` and multiplied by ``scale``
        """
        # drop the singleton leading axis left by the per-GPU split
        range_angular = range_angular.squeeze()
        range_spatial_x_lr_pad = range_spatial_x_lr_pad.squeeze()
        range_spatial_y_lr_pad = range_spatial_y_lr_pad.squeeze()
        range_spatial_x_hr_pad = range_spatial_x_hr_pad.squeeze()
        range_spatial_y_hr_pad = range_spatial_y_hr_pad.squeeze()

        ### --- First, generate LR PSV
        B, h, w = lf_input.shape[0], lf_input.shape[2], lf_input.shape[3]
        # the aug_scale is divided by the SR scale for LR PSV generation;
        # a new value is bound so a tensor argument is not mutated in place
        lr_aug_scale = aug_scale / self.scale
        PSV = self.PSV_generator(lf_input, lr_aug_scale,
                                 range_angular,
                                 range_spatial_x_lr_pad,
                                 range_spatial_y_lr_pad)  # [LB, UV, h, w]
        ### --- Second, disparity estimation
        lr_disp = self._estimate_disparity(PSV, B, h, w)
        ### --- Third, SR
        bic_lf = self.resizer(lf_input, self.scale)  # [B, UV, H, W]
        vdsr_lf, vdsr_ref = self.VDSR_net(bic_lf)

        # disparity values grow with resolution, hence the extra factor
        bic_disp = self.resizer(lr_disp, self.scale) * self.scale
        disparity = bic_disp.repeat(1, 2, 1, 1)
        ### --- Finally, padding, warping and cropping
        vdsr_lf = self.padder(vdsr_lf)
        disparity = self.padder(disparity)

        aligned_vdsr_lf = warp_to_ref_view_parallel_double_range(vdsr_lf,
                                                               disparity,
                                                               self.refPos,
                                                               range_angular,
                                                               range_spatial_x_hr_pad,
                                                               range_spatial_y_hr_pad,
                                                               padding_mode=self.padding_mode)
        ### --- Crop
        aligned_vdsr_lf = self._crop(aligned_vdsr_lf)
        ### --- Last, fusion
        sr_ref = self._fuse(aligned_vdsr_lf, vdsr_ref)
        return aligned_vdsr_lf, vdsr_ref, sr_ref, lr_disp, bic_disp


class AlignWithAggreNet_ForTest(nn.Module):
    """Inference variant of the alignment + aggregation network.

    It holds the same layers as the training network but refines the views one
    at a time in ``VDSR_net`` and can build the plane-sweep volume (PSV) one
    disparity level at a time (``PSV_generator_serial``). Warping uses
    ``warp_to_ref_view_serial_no_range``, which is defined outside this file.
    """

    def __init__(self,
                 refPos,
                 scale=2,
                 conv=default_conv,
                 padding_mode="zeros",
                 level_num=100,
                 level_step=0.1,
                 pad_size=12,
                 view_num=9):
        """
        :param refPos:        two-element index ``(u, v)`` of the reference view
        :param scale:         upsampling factor passed to ``self.resizer``
        :param conv:          layer factory called as ``conv(in_channels, out_channels, kernel_size)``
        :param padding_mode:  forwarded unchanged to the warping helper
        :param level_num:     number of disparity levels in the PSV
        :param level_step:    disparity distance between adjacent levels (before scaling by aug_scale)
        :param pad_size:      reflection padding added before warping and cropped afterwards
        :param view_num:      views per angular axis; inputs carry ``view_num * view_num`` channels
        """
        super(AlignWithAggreNet_ForTest, self).__init__()

        self.refPos = refPos
        self.level_num = level_num
        self.level_step = level_step
        self.view_num = view_num
        self.scale = scale
        self.padding_mode = padding_mode
        self.pad_size = pad_size
        self.padder = nn.ReflectionPad2d(self.pad_size)
        # parameter-free helpers: activation and image resizer
        self.relu = nn.ReLU(inplace=True)
        self.resizer = bicubic_imresize()
        # feature extraction: 1x1 conv over the angular (view) channels
        self.feat_conv = conv(in_channels=self.view_num * self.view_num, out_channels=2, kernel_size=1)
        # disparity estimation: consumes 2 feature channels per disparity level
        self.disp_conv1 = conv(2 * self.level_num, 100, 7)
        self.disp_conv2 = conv(100, 100, 5)
        self.disp_conv3 = conv(100, 50, 3)
        self.disp_conv4 = conv(50, 1, 1)

        # VDSR: single-channel residual refinement applied to each view
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=True)

        # fusion net: all fusion layers share one kernel size
        self.kernel_size = 7

        self.fusion_conv1 = conv(self.view_num * self.view_num, 32, self.kernel_size)
        self.fusion_conv2 = conv(32, 64, self.kernel_size)
        self.fusion_conv3 = conv(64, 32, self.kernel_size)
        self.fusion_conv4 = conv(32, 16, self.kernel_size)
        self.fusion_tail = conv(16, 1, self.kernel_size)
        # NOTE: this constructor applies no explicit weight initialisation.

    def VDSR_net(self, bic_lf):
        """
        Refine the upsampled views one at a time to keep peak memory low.

        :param bic_lf:   [N, U*V, H, W] upsampled light field
        :return: (vdsr_lf [N, U*V, H, W], vdsr_ref [N, 1, H, W]); vdsr_ref is
                 ``None`` when the flattened reference index is not in ``range(U*V)``
        """
        UV = bic_lf.shape[1]
        # flattened (row-major) channel index of the reference view
        ref_index = self.view_num * self.refPos[0] + self.refPos[1]
        refined_views = []
        for i in range(UV):
            bic_sub = bic_lf[:, i, :, :].unsqueeze(1)
            feat = self.relu(self.input(bic_sub))
            feat = self.residual_layer(feat)
            refined_views.append(self.output(feat) + bic_sub)
        vdsr_ref = refined_views[ref_index] if 0 <= ref_index < UV else None
        vdsr_lf = torch.cat(refined_views, dim=1)
        return vdsr_lf, vdsr_ref

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` freshly constructed ``block()`` modules sequentially."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def _disparity_planes(self, padded_lf, aug_scale):
        """Lazily yield one constant disparity map [B, 2, h, w] per PSV level."""
        B, _, h, w = padded_lf.shape
        # rescale the level step for multi-scale generalization
        level_step = self.level_step * aug_scale
        center = (self.level_num - 1) / 2.0
        # allocate on the input's own device/dtype instead of CPU + .cuda()
        ones = padded_lf.new_ones((B, 2, h, w))
        for level in range(self.level_num):
            yield ones * ((level - center) * level_step)

    def _crop(self, x):
        """Remove the ``pad_size`` border from the last two axes (no-op for pad_size == 0)."""
        p = self.pad_size
        # explicit end indices: a ``p:-p`` slice would be empty when p == 0
        return x[:, :, p:x.shape[2] - p, p:x.shape[3] - p]

    def PSV_generator(self, input_lf, aug_scale):
        """
        Build the whole PSV with a single warping call (used for "low" resolution inputs).

        :param input_lf:   [B, U*V, h, w]
        :param aug_scale:  float, scales ``level_step``
        :return: PSV       [LB, UV, h, w], level-major along the first axis
        """
        input_lf = self.padder(input_lf)
        disparity_levels = torch.cat(list(self._disparity_planes(input_lf, aug_scale)), dim=0)  # [LB, 2, h, w]
        input_lf = input_lf.repeat(self.level_num, 1, 1, 1)  # [LB, UV, h, w]
        PSV = warp_to_ref_view_serial_no_range(input_lf, disparity_levels, self.refPos,
                                               padding_mode=self.padding_mode)  # [LB, UV, h, w]
        return self._crop(PSV)  # [LB, UV, h, w]

    def PSV_generator_serial(self, input_lf, aug_scale):
        """
        Build the PSV one disparity level at a time to limit peak memory.

        :param input_lf:   [B, U*V, h, w]
        :param aug_scale:  float, scales ``level_step``
        :return: PSV       [LB, UV, h, w], level-major along the first axis
        """
        input_lf = self.padder(input_lf)
        PSV = []
        # planes are produced lazily, so only one level is warped at a time
        for disparity_level in self._disparity_planes(input_lf, aug_scale):
            PSV_l = warp_to_ref_view_serial_no_range(input_lf, disparity_level, self.refPos,
                                                     padding_mode=self.padding_mode)  # [B, UV, h, w]
            PSV.append(self._crop(PSV_l))  # [B, UV, h, w]
        return torch.cat(PSV, dim=0)  # [LB, UV, h, w]

    def _estimate_disparity(self, PSV, B, h, w):
        """Regress the low-resolution disparity map from a level-major PSV [LB, UV, h, w]."""
        depth_feats = self.relu(self.feat_conv(PSV))  # [LB, 2, h, w]
        depth_feats = depth_feats.view(self.level_num, B, -1, h, w)  # [L, B, 2, h, w]
        depth_feats = depth_feats.permute(1, 2, 0, 3, 4)  # [B, 2, L, h, w]
        depth_feats = depth_feats.contiguous().view(B, -1, h, w)  # [B, 2L, h, w]
        feats = self.relu(self.disp_conv1(depth_feats))
        feats = self.relu(self.disp_conv2(feats))
        feats = self.relu(self.disp_conv3(feats))
        return self.disp_conv4(feats)

    def _fuse(self, aligned_vdsr_lf, vdsr_ref):
        """Aggregate the aligned views into a residual and add it to the reference view."""
        feat = self.relu(self.fusion_conv1(aligned_vdsr_lf))
        feat = self.relu(self.fusion_conv2(feat))
        feat = self.relu(self.fusion_conv3(feat))
        feat = self.relu(self.fusion_conv4(feat))
        return self.fusion_tail(feat) + vdsr_ref

    def forward(self, lf_input, aug_scale, dset_resolution):
        """
        The forward should concern about the augmentation scale.

        :param lf_input: [B, U*V, h, w]
        :param aug_scale: a number less than 1 (never modified in place)
        :param dset_resolution: "high" for HCI1 and "low" for others like Stanford;
                                "high" builds the PSV level by level, "low" in one call
        :return: tuple ``(aligned_vdsr_lf, vdsr_ref, sr_ref, lr_disp, bic_disp)``:
                 aligned_vdsr_lf -- refined views after warping and cropping,
                 vdsr_ref        -- [B, 1, H, W] refined reference view,
                 sr_ref          -- fusion output added to vdsr_ref,
                 lr_disp         -- single-channel low-resolution disparity,
                 bic_disp        -- lr_disp resized by ``scale`` and multiplied by ``scale``
        :raises ValueError: if ``dset_resolution`` is neither "high" nor "low"
        """
        ### --- First, generate LR PSV
        B, h, w = lf_input.shape[0], lf_input.shape[2], lf_input.shape[3]
        # the aug_scale is divided by the SR scale for LR PSV generation;
        # a new value is bound so a tensor argument is not mutated in place
        lr_aug_scale = aug_scale / self.scale
        if dset_resolution == "high":
            PSV = self.PSV_generator_serial(lf_input, lr_aug_scale)  # [LB, UV, h, w]
        elif dset_resolution == "low":
            PSV = self.PSV_generator(lf_input, lr_aug_scale)
        else:
            raise ValueError("Wrong dset_resolution: expected \"high\" or \"low\", got %r" % (dset_resolution,))
        ### --- Second, disparity estimation
        lr_disp = self._estimate_disparity(PSV, B, h, w)
        ### --- Third, SR
        bic_lf = self.resizer(lf_input, self.scale)  # [B, UV, H, W]
        vdsr_lf, vdsr_ref = self.VDSR_net(bic_lf)

        # disparity values grow with resolution, hence the extra factor
        bic_disp = self.resizer(lr_disp, self.scale) * self.scale
        disparity = bic_disp.repeat(1, 2, 1, 1)
        ### --- Finally, padding, warping and cropping
        vdsr_lf = self.padder(vdsr_lf)
        disparity = self.padder(disparity)

        aligned_vdsr_lf = warp_to_ref_view_serial_no_range(vdsr_lf,
                                                         disparity,
                                                         self.refPos,
                                                         padding_mode=self.padding_mode)
        ### --- Crop
        aligned_vdsr_lf = self._crop(aligned_vdsr_lf)
        ### --- Last, fusion
        sr_ref = self._fuse(aligned_vdsr_lf, vdsr_ref)
        return aligned_vdsr_lf, vdsr_ref, sr_ref, lr_disp, bic_disp